import websocket
import threading
import logging
import queue
import json
import time

from .microphone import iterate_microphone
from .config import Config
from .mixer import MixerControl

sample_rate = 16_000  # Hz; sent in the ASR START payload and used as the debug WAV frame rate

class SpeechRecognizer:
    """Stream microphone audio to Baidu's realtime ASR websocket and yield text.

    Constructing an instance starts two daemon threads:

    * ``recorder_thread`` reads ``iterate_microphone`` and queues PCM chunks on
      ``pcm_queue``. On the first chunk of every utterance (the chunks up to and
      including one with ``finish=True``) it opens a new websocket session in
      ``app_thread``.
    * ``mixer_thread`` keeps ``MixerControl.music_supressor`` active while chunks
      keep arriving and ``MixerControl.any_playing()`` reported something playing.

    Each session's ``sender_thread`` forwards one utterance of PCM. ``on_message``
    pushes ``(text, finish)`` results onto ``text_queue``, and iterating the
    instance consumes them.
    """

    def send_start_params(self, ws: websocket.WebSocket, dev_pid: int = 1537):
        """Send the START text frame with credentials and audio format.

        ``dev_pid`` is forwarded unchanged in the payload (presumably the Baidu
        model id -- verify against the API docs before changing the default).
        """
        body = json.dumps({
            "type": "START",
            "data": {
                "appid": Config.auth_app_id,
                "appkey": Config.auth_api_key,
                "dev_pid": dev_pid,
                "cuid": Config.auth_cuid,
                "sample": sample_rate,
                "format": "pcm",
            },
        })
        ws.send(body, websocket.ABNF.OPCODE_TEXT)

    def send_audio(self, ws: websocket.WebSocket, pcm: bytes):
        """Send one chunk of raw PCM as a binary frame."""
        ws.send(pcm, websocket.ABNF.OPCODE_BINARY)

    def send_finish(self, ws: websocket.WebSocket):
        """Send the FINISH text frame after the last audio chunk of a session."""
        body = json.dumps({
            "type": "FINISH",
        })
        ws.send(body, websocket.ABNF.OPCODE_TEXT)

    def send_cancel(self, ws: websocket.WebSocket):
        """Send the CANCEL text frame."""
        body = json.dumps({
            "type": "CANCEL",
        })
        ws.send(body, websocket.ABNF.OPCODE_TEXT)

    def _discard_utterance(self):
        """Drop queued PCM chunks up to and including the next ``finish`` marker.

        Used when a session cannot deliver its audio. Without it, the remaining
        chunks of that utterance would be picked up by the next session.
        """
        while True:
            _, finish = self.pcm_queue.get()
            if finish:
                return

    def sender_thread(self, ws: websocket.WebSocket):
        """Upload one utterance: START, every queued chunk up to ``finish``, FINISH.

        If a send fails (e.g. the connection dropped), the rest of the current
        utterance is still consumed from ``pcm_queue`` so the next session starts
        on a clean utterance boundary. When ``is_debug`` is set, the chunks read
        before any failure are written to ``/tmp/output.wav`` (mono, 16-bit,
        ``sample_rate``).
        """
        # bytearray grows in amortized O(1); `bytes +=` would copy the whole buffer
        # on every chunk.
        debug_audio = bytearray()
        finish = False
        try:
            self.send_start_params(ws)
            while not finish:
                pcm, finish = self.pcm_queue.get()
                if len(pcm) != 0:
                    self.send_audio(ws, pcm)
                if self.is_debug:
                    debug_audio += pcm
            self.send_finish(ws)
        except (websocket.WebSocketException, OSError) as e:
            self.logger.error("SEND ERROR: {}".format(e))
            if not finish:
                # The finish marker for this utterance is still queued; consume it
                # so pcm_queue stays aligned with utterance boundaries.
                self._discard_utterance()
        self.logger.debug("sender thread exit")
        if self.is_debug:
            import wave
            with wave.open("/tmp/output.wav", "wb") as f:
                f.setnchannels(1)
                f.setsampwidth(2)
                f.setframerate(sample_rate)
                f.writeframes(debug_audio)

    def mixer_thread(self):
        """Keep music suppressed while speech chunks keep arriving.

        Commands read from ``mixer_queue``:

        * ``True``: a PCM chunk was queued (sent by ``recorder_thread``).
        * ``False``: a non-empty final transcript arrived while loud (sent by
          ``__iter__``).
        * ``None``: exit the thread.
        """
        while True:
            # A blocking get() never raises queue.Empty, so no handler is needed.
            command = self.mixer_queue.get()
            if command is None:
                return
            if command is not True:
                continue
            self.is_loud = self.mixer.any_playing()
            if not self.is_loud:
                continue
            with self.mixer.music_supressor(fade=2):
                while True:
                    # Stay suppressed while commands keep arriving at least every
                    # 1.5 s; a False command ends this wait early.
                    while True:
                        try:
                            command = self.mixer_queue.get(timeout=1.5)
                        except queue.Empty:
                            break
                        if command is None:
                            return
                        if command is False:
                            break
                    # Grace period: any other command within 1.25 s re-enters the
                    # wait above; silence or False leaves the suppressor.
                    try:
                        command = self.mixer_queue.get(timeout=1.25)
                    except queue.Empty:
                        break
                    if command is None:
                        return
                    if command is False:
                        break
            time.sleep(1)

    def recorder_thread(self):
        """Read the microphone and queue audio for the current session.

        ``iterate_microphone`` yields ``(audio, finish)``. ``audio`` is sliced from
        the previously forwarded offset ``n``, i.e. it is treated as cumulative
        within an utterance. The first chunk of each utterance starts a new
        ``app_thread``. A ``finish`` chunk resets both the session flag and the
        offset.
        """
        n = 0
        started = False
        self.logger.debug("recorder thread started")
        options: dict = Config.mic_options
        for audio, finish in iterate_microphone(**options):
            if not started:
                self.start_thread(self.app_thread)
                started = True
            if finish:
                started = False
            m = len(audio)
            pcm = audio[n:]
            # // 2: length in 16-bit samples (the debug WAV uses a 2-byte width).
            self.logger.info("PCM length: {}".format(len(pcm) // 2))
            self.pcm_queue.put((pcm, finish))
            # Tell the mixer thread that speech audio is still arriving.
            self.mixer_queue.put(True)
            n = 0 if finish else m
        self.mixer_queue.put(None)
        self.logger.debug("recorder thread exit")

    def on_open(self, ws: websocket.WebSocket):
        """Websocket open callback: start this session's sender thread."""
        self.logger.info("socket open")
        self.start_thread(self.sender_thread, ws)

    def on_message(self, _: websocket.WebSocket, message: str):
        """Handle one server frame and queue MID_TEXT/FIN_TEXT results.

        Results are put on ``text_queue`` as ``(result, finish)``, where ``finish``
        is true for FIN_TEXT. HEARTBEAT and other frame types are not queued.
        """
        self.logger.info("socket message: {}".format(message))
        data = json.loads(message)
        msg_type = data.get('type')
        if msg_type == 'HEARTBEAT':
            return
        err_no = data.get('err_no', 0)
        # -3005 is deliberately excluded from error logging by this code.
        if err_no not in (0, -3005):
            self.logger.error("API ERROR {}: {}".format(err_no, data.get('err_msg')))
        if msg_type not in ('MID_TEXT', 'FIN_TEXT'):
            return
        result = data['result']
        finish = msg_type == 'FIN_TEXT'
        self.text_queue.put((result, finish))

    def on_error(self, _: websocket.WebSocket, error: str):
        """Websocket error callback: log the error."""
        self.logger.error("SOCKET ERROR: {}".format(error))

    def on_close(self, *_):
        """Websocket close callback; accepts any callback arguments and only logs."""
        self.logger.info("socket closed")

    def __init__(self, is_debug: bool = True):
        """Create the queues and mixer, then start the recorder and mixer threads.

        :param is_debug: when true (the default, matching the previous hard-coded
            value), each session's sent audio is also written to
            ``/tmp/output.wav``.
        """
        self.logger = logging.getLogger()
        # (pcm, finish) chunks: recorder_thread -> sender_thread.
        self.pcm_queue = queue.Queue()
        # (text, finish) results: on_message -> __iter__.
        self.text_queue = queue.Queue()
        # True / False / None commands for mixer_thread.
        self.mixer_queue = queue.Queue()
        self.mixer = MixerControl()
        self.is_loud = False
        self.is_debug = is_debug
        # Threads start last so every attribute above exists before they run.
        self.start_thread(self.recorder_thread)
        self.start_thread(self.mixer_thread)

    def start_thread(self, target, *args, **kwargs):
        """Run ``target(*args, **kwargs)`` on a new daemon thread."""
        threading.Thread(target=target, args=args, kwargs=kwargs, daemon=True).start()

    def __iter__(self):
        """Yield ``(text, finish)`` results forever, blocking on ``text_queue``.

        While ``is_loud``:

        * ``Config.loud_prefix`` (if set) is stripped from the text.
        * A final result that is non-empty after stripping whitespace and '，。'
          sends ``False`` to the mixer thread to end its current wait.
        """
        while True:
            text, finish = self.text_queue.get()
            if self.is_loud:
                if Config.loud_prefix and text.startswith(Config.loud_prefix):
                    text = text[len(Config.loud_prefix):]
                if text.strip().strip('，。') and finish:
                    self.mixer_queue.put(False)
            yield text, finish

    def app_thread(self):
        """Open one websocket session and block in ``run_forever`` until it ends.

        If the socket never opens, no ``sender_thread`` runs for this utterance,
        so its queued chunks are discarded here. Otherwise they would leak into
        the next session.
        """
        self.logger.info("starting app")
        uri = "ws://vop.baidu.com/realtime_asr?sn=" + Config.auth_snid
        opened = threading.Event()

        def on_open(ws: websocket.WebSocket):
            opened.set()
            self.on_open(ws)

        ws_app = websocket.WebSocketApp(uri,
                                        on_open=on_open,
                                        on_message=self.on_message,
                                        on_error=self.on_error,
                                        on_close=self.on_close)
        ws_app.run_forever()
        if not opened.is_set():
            self.logger.error("socket never opened; discarding queued audio")
            self._discard_utterance()

if __name__ == "__main__":
    # Script entry (run with `python -m <package>.<module>` so the relative
    # import below resolves): feed every (text, finish) result into TypeWriter.
    log_format = '[%(asctime)-15s] [%(funcName)s()][%(levelname)s] %(message)s'
    logging.basicConfig(format=log_format)
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    from .typer import TypeWriter
    writer = TypeWriter()
    recognizer = SpeechRecognizer()
    for text, finish in recognizer:
        writer.input(text, finish)
